Add bypass distillation (blockwise local KD) to puzzletron pipeline #1111

Separius wants to merge 5 commits into feature/puzzletron from
Conversation
Bypass distillation trains alternative transformer block configurations using per-block knowledge distillation from the teacher model, producing a library of better "puzzle pieces" for the MIP solver. It is most beneficial for aggressive FFN pruning (≤ 1/8 of teacher width) or significant KV head compression.

Changes:

- Add modelopt/torch/puzzletron/bypass_distillation/ module with full training loop, stitched model factory, checkpoint management, and data classes
- Integrate bypass as optional Step 3 in puzzletron.py and puzzletron_nas_plugin.py (pipeline progress counter updates to 9 steps when bypass is enabled)
- Add HuggingFace auto-download and skip-if-exists logic to puzzletron_nas_plugin.py for all pipeline steps
- Add normalized_mse_loss, vectorwise_normalized_mse_loss, and batched_normalized_mse_loss to sewing_kit/utils.py
- Fix child_init.py: support list of pruning mixins; fix None override treated as "keep original value" instead of raising TypeCheckError
- Fix dataset.py: graceful fallback when tokenizer has no chat_template (base models)
- Add _copy_auto_map_code_files to checkpoint_utils_hf.py so modeling Python files are copied alongside config.json (required for trust_remote_code checkpoints such as NemotronH)
- Add create_train_dataloader to dataloaders.py
- Add MoEChannelPruning to MlpInitMode enum
- Add default pruning_mixins() to ModelDescriptor base class
- Add NemotronHKVHeadsLayerDescriptor and kv_heads mixin to NemotronH descriptor; fix _set_keys_to_learn to skip Mamba blocks during subblock_attention bypass (based on block config)
- Enable bypass in llama-3_1-8B_pruneffn_memory config; add example bypass/defaults.yaml
- Update README with bypass documentation: when to use, time cost, sequential execution, W&B logging
- Add unit tests for loss functions and distribution utilities
- Add GPU integration tests for bypass (FFN pruning, KV compression, multi-config sweep, checkpoint validation)
- Fix test_puzzletron.py assertion to handle variable GPU counts
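The PR only names the new losses; as a rough sketch of what a normalized MSE usually looks like (the function name matches the PR, but the body and epsilon handling here are assumptions, written in plain Python rather than torch):

```python
def normalized_mse_loss(pred, target, epsilon=1e-8):
    """MSE between student and teacher activations, scaled by the
    teacher's own magnitude so blocks with different activation
    scales contribute comparably. Pure-Python sketch, not the PR's code."""
    mse = sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)
    target_power = sum(t ** 2 for t in target) / len(target)
    # epsilon floors the final scalar denominator rather than shifting
    # the target tensor (the review below flags the shifted variant)
    return mse / (target_power + epsilon)
```

A perfect match yields 0, and a student emitting all zeros against a unit-power teacher yields roughly 1, which makes per-block losses comparable across blocks.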
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the settings.
📝 Walkthrough

Adds an optional bypass (blockwise local) distillation stage: new bypass package with stitched teacher–student factory, distributed training loop and checkpointing, model/pruning extensions, normalized-MSE losses, dataloader helper, example configs/docs, and unit/GPU tests; integrates bypass into Puzzletron control flow.
Sequence Diagram(s)

sequenceDiagram
actor User
participant Launcher as "launch_bypass_distillation(hydra_cfg)"
participant Orchestrator as "run_bypassed_training(cfg)"
participant Factory as "stitched_model_factory()"
participant Data as "DataLoader / Teacher"
participant Trainer as "train()"
participant Checkpoint as "save_bypass_checkpoint()"
User->>Launcher: provide Hydra cfg (single or sweep)
Launcher->>Orchestrator: start run(s)
Orchestrator->>Data: load teacher model & dataloaders
Orchestrator->>Factory: build stitched teacher & student modules
Factory-->>Orchestrator: return stitched modules + descriptors
Orchestrator->>Trainer: start training loop
loop per iteration
Trainer->>Data: fetch batch
Trainer->>Trainer: teacher forward -> capture activations
Trainer->>Trainer: student forward -> compute per-block losses
Trainer->>Trainer: backward, grad scale/clip, optimizer step
Trainer->>Checkpoint: conditional save, write markers, symlink
Checkpoint-->>Trainer: sync / resume info
end
Trainer-->>Orchestrator: training complete
Orchestrator-->>Launcher: run finished
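The per-block loop in the diagram can be illustrated with a toy scalar model (one "block" per function; this is an illustration of local KD, not the PR's training loop, and all names here are made up):

```python
def local_kd_step(w, x_batch, teacher, lr=0.1):
    """One gradient step matching a scalar student block y = w*x to a
    frozen teacher block, using only the block's local MSE: no gradient
    flows between blocks, which is what makes the KD 'local'."""
    grad = sum(2 * (w * x - teacher(x)) * x for x in x_batch) / len(x_batch)
    return w - lr * grad

teacher = lambda x: 2.0 * x  # frozen teacher block
w = 0.0                      # student block starts untrained
for _ in range(200):
    w = local_kd_step(w, [0.5, 1.0, 1.5], teacher)
# w converges toward the teacher's weight of 2.0
```

Each candidate block configuration is trained this way against the teacher's captured activations, independently of the other blocks, which is what lets the MIP solver later mix and match the resulting "puzzle pieces".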
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes

🚥 Pre-merge checks: ✅ 3 | ❌ 1

❌ Failed checks (1 warning)
✅ Passed checks (3 passed)
Actionable comments posted: 8
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tests/gpu/torch/puzzletron/test_puzzletron.py (1)
236-245: ⚠️ Potential issue | 🟡 Minor

The fallback printer still emits only rank-local values. This branch now advertises num_layers={total_layers}, but it still prints only the contents of rank_{rank}.pth and is executed on rank 0 only. On multi-GPU runs the suggested EXPECTED_PRUNING_VALUES snippet will be incomplete.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gpu/torch/puzzletron/test_puzzletron.py` around lines 236 - 245, The printer currently outputs only rank-local pruning_scores causing incomplete EXPECTED_PRUNING_VALUES for multi-GPU runs; modify the logic so rank 0 aggregates pruning data from all ranks before printing: collect and merge per-rank pruning_scores (or load all rank_{rank}.pth files) into a global pruning_scores for each layer_name, compute the global score and channels (e.g., combine/average or gather channel indices across ranks) respecting total_layers, and then have rank 0 iterate over layer_names using the aggregated values when printing the block that uses total_layers and prints the EXPECTED_PRUNING_VALUES snippet.

modelopt/torch/puzzletron/pruning/pruning_utils.py (1)
40-47: ⚠️ Potential issue | 🟠 Major

MoEChannelPruning is exposed before the init path supports it. modelopt/torch/puzzletron/tools/bypassed_training/child_init.py now branches on this enum and forwards it into _init_mlp_module(), but _init_mlp_module() still falls through to "Unsupported mlp_init_mode" for this value when expert widths change. Any config that selects MoEChannelPruning will fail during child initialization.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/puzzletron/pruning/pruning_utils.py` around lines 40 - 47, The enum MlpInitMode now includes MoEChannelPruning but _init_mlp_module still treats that case as unsupported; update the _init_mlp_module implementation to handle MlpInitMode.MoEChannelPruning (the same call-site that child_init.py forwards into) by adding a branch for MlpInitMode.MoEChannelPruning that performs the correct initialization when expert widths change (e.g., adapt the weight/activation shapes by slicing/reshaping or reuse the ConcatExpertsIntoDenseFFN logic where appropriate), so the child init no longer falls through to the "Unsupported mlp_init_mode" error for MoEChannelPruning.
🧹 Nitpick comments (5)
modelopt/torch/puzzletron/tools/bypassed_training/child_init.py (1)
804-806: This change makes explicit null resets impossible.

Treating None as "keep original" fixes the accidental overwrite, but it also removes the only way for JSON/YAML overrides to clear an optional field back to None. If callers need both behaviors, use a sentinel for "no override" and reserve None for explicit clearing.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/puzzletron/tools/bypassed_training/child_init.py` around lines 804 - 806, The current override function (override) treats item_overrides == None as "keep original", which prevents callers from explicitly clearing a value to None via JSON/YAML; change the logic to use a distinct sentinel (e.g., a new unique object like NO_OVERRIDE) to represent "no override" and reserve None in item_overrides to mean "set to None"/clear the field, updating the override function to check against the sentinel (NO_OVERRIDE) instead of None and adjust any callers that construct overrides to use the sentinel when they mean "leave original".

modelopt/torch/puzzletron/utils/data/dataset.py (1)
123-130: Keep role markers in the no-template fallback.

Joining only content collapses system/user/assistant turns into plain text, which changes the supervision for chat datasets. A lightweight fallback like "{role}: {content}" preserves the conversation structure without relying on a tokenizer template.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/puzzletron/utils/data/dataset.py` around lines 123 - 130, The fallback that builds sample when getattr(self.tokenizer, "chat_template", None) is None should preserve role markers instead of joining only message["content"]; update the else branch in dataset.py (the block that currently sets sample = "\n".join(m["content"] for m in sample)) to join messages using a lightweight role-prefixed format like "{role}: {content}" so conversation turns (system/user/assistant) are retained; keep using the same sample variable and ensure this mirrors the structure expected by downstream code that consumes apply_chat_template outputs.

modelopt/torch/puzzletron/utils/parsing.py (1)
337-345: Don't silently treat every NaN as a no-op block.

This formatter now drops any NaN entry and can report "No trainable blocks found". If a trainable block diverges, the failure disappears from the logs instead of surfacing. Filter only known skipped block types, or emit a separate warning for unexpected NaNs.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/puzzletron/utils/parsing.py` around lines 337 - 345, The current filtering silently drops any NaN in losses_dict (and prunes best_steps_dict/best_values_dict to match), which hides diverging trainable blocks; instead, update the logic around losses_dict, best_steps_dict and best_values_dict so you only drop entries whose keys match known skipped block types (e.g., the explicit list of no-op block names like "Mamba"), and for any other NaN values emit a warning/error (via the existing logger) that a trainable block produced NaN rather than removing it; ensure best_steps_dict and best_values_dict are only pruned to match the filtered losses_dict after this selective filtering and warning behavior.

examples/puzzletron/main.py (1)
154-167: Progress messages in run_mip_only are hardcoded and inconsistent with the dynamic approach.

The run_full_puzzletron function now uses dynamic step counting (N = _total_steps(hydra_cfg)), but run_mip_only still uses hardcoded "7/8" and "8/8" progress messages. If bypass is configured, the step numbers would be incorrect (should be 8/9 and 9/9).

Consider applying the same dynamic step count logic here for consistency.
♻️ Suggested fix
 def run_mip_only(hydra_config_path: str):
     ...
     # Load hydra config
     hydra_cfg = initialize_hydra_config_for_dir(
         config_dir=hydra_config_dir,
         config_name=hydra_config_name,
         overrides=[],
     )
+    N = _total_steps(hydra_cfg)
     # Check if sweep mode is enabled
     if hasattr(hydra_cfg.mip, "sweep") and hydra_cfg.mip.sweep.get("enabled", False):
         mprint(
-            "Puzzletron Progress 7/8: running MIP sweep for multiple compression rates (multi-gpu)"
+            f"Puzzletron Progress {N-1}/{N}: running MIP sweep for multiple compression rates (multi-gpu)"
         )
         sweep.run_mip_sweep(hydra_cfg)
     else:
         # mip_and_realize_models (distributed processing)
         # TODO: How to make it part of mnt.search() api, similarly to run_full_puzzletron() API
-        mprint("Puzzletron Progress 7/8: running MIP and realizing models (multi-gpu)")
+        mprint(f"Puzzletron Progress {N-1}/{N}: running MIP and realizing models (multi-gpu)")
         mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg)
     dist.cleanup()
-    mprint("Puzzletron Progress 8/8: puzzletron pipeline completed (multi-gpu)")
+    mprint(f"Puzzletron Progress {N}/{N}: puzzletron pipeline completed (multi-gpu)")

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/puzzletron/main.py` around lines 154 - 167, Update run_mip_only to compute the total steps like run_full_puzzletron by calling _total_steps(hydra_cfg) and use that N when formatting the progress messages instead of hardcoded "7/8" and "8/8"; specifically, replace the two mprint calls around the conditional that currently show "Puzzletron Progress 7/8" and "8/8" with dynamic messages using N (e.g., f"Puzzletron Progress {current_step}/{N}: ...") and ensure current_step increments are correct for both the sweep branch (sweep.run_mip_sweep) and the mip branch (mip_and_realize_models.launch_mip_and_realize_model) so progress displays consistently with _total_steps.

modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py (1)
548-556: Unused variable num_trainable_params.

The variable num_trainable_params is computed but never used in this function or elsewhere. This appears to be residual code. Consider removing it to reduce unnecessary computation and improve code clarity.

♻️ Proposed removal
     assert "learning_rate" in cfg.training
-    num_trainable_params = sum(
-        p.requires_grad and submodule_name in p_name
-        for p_name, p in student_stitched_module.named_parameters()
-        if "dummy_param" not in p_name  # exclude placeholder params
-    )
-    # Do NOT enable dummy params: blocks with no real trainable parameters
-    # (e.g. Mamba blocks during an attention-only bypass run) should produce
-    # NaN loss so they are excluded from statistics — identical to the
-    # optimizer=None path in the training loop.
     student_module_parameters = {

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py` around lines 548 - 556, Remove the unused computation of num_trainable_params: delete the sum(...) assignment that iterates over student_stitched_module.named_parameters() checking p.requires_grad and submodule_name in p_name (and the "dummy_param" exclusion). Keep the surrounding explanatory comment about dummy params if still relevant, but eliminate the dead variable and its associated needless iteration to avoid wasted computation and clarify stitched_model_factory.py.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py`:
- Around line 45-58: The fallback currently only sorts checkpoint directories by
iteration (get_iter_num) so when multiple checkpoints exist for the same iter we
may pick an older step; update the selection to parse both iteration and step
and sort by (iter, step) descending. Modify get_iter_num (or replace with
get_iter_and_step) to extract iter via r"iter-(\d+)" and step via r"step-(\d+)"
(defaulting to 0 when not present), return a tuple of two ints, and use that
tuple as the sort key for checkpoint_dirs before iterating to find the first
directory with "saving_completed" and returning str(latest_dir).
In `@modelopt/torch/puzzletron/bypass_distillation/training_loop.py`:
- Around line 673-677: Replace the hardcoded trust_remote_code=True in the
AutoTokenizer.from_pretrained call with the same caller-configurable
trust_remote_code flag you already read from the descriptor earlier (the
variable used for model config loading at lines ~597/631); specifically update
the tokenizer = AutoTokenizer.from_pretrained(...) invocation that uses
cfg.teacher_dir so it passes the descriptor-derived trust_remote_code value
instead of True, ensuring the flag remains configurable and defaults to False.
In `@modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py`:
- Around line 146-149: The pre-checks that treat presence of files like
(teacher_dir / "config.json"), any rank_*.pth, files under
pruned_ckpts_output_dir, or library outputs as sufficient to skip stages are
unsafe; change these guards to rely on durable completion markers (e.g., a .done
or .complete file) created at the successful end of
conversion/scoring/pruning/library build instead of existence-only checks, so
functions like the conversion branch around teacher_dir/config.json, the rank_*
checkpoint checks, and the pruned_ckpts_output_dir/library checks only skip when
their corresponding completion marker exists; ensure launch_score_activations()
remains the stricter gate for pruning-activation scoring but remove or weaken
the naive existence checks noted at the conversion lines (the block using
teacher_dir/config.json) and the other mentioned blocks (191-193, 286-289) to
check for the specific "<stage>.complete" marker before skipping.
In `@modelopt/torch/puzzletron/sewing_kit/utils.py`:
- Around line 452-454: The normalization denominator is computed as
F.mse_loss(target, torch.zeros_like(target) + epsilon, ...) which shifts the
target by epsilon and biases the scale; instead compute the denominator as
F.mse_loss(target, torch.zeros_like(target), reduction=reduction) + epsilon (or
clamp_min the denominator to epsilon) so you add epsilon to the final scalar
denominator instead of to the zero tensor; update the occurrences around the
loss assignment (loss, input, target, epsilon, F.mse_loss) and the similar block
at lines 479-482 accordingly.
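The epsilon placement above matters numerically: shifting the target tensor by epsilon changes the denominator quadratically, while adding epsilon to the final scalar only floors it. A plain-float sketch of the two variants (hypothetical helpers, not the repo's code):

```python
def denom_shifted(target, eps):
    # what the flagged code effectively computes: mse(target, eps * ones)
    return sum((t - eps) ** 2 for t in target) / len(target)

def denom_floored(target, eps):
    # the suggested fix: mse(target, zeros) + eps
    return sum(t ** 2 for t in target) / len(target) + eps

# with an all-zero target, the shifted form collapses to eps**2 (far too
# small a floor for division), while the floored form keeps the eps floor
zero_target = [0.0, 0.0]
```

Dividing by eps**2 instead of eps can blow up the normalized loss by many orders of magnitude on near-silent activations, which is why clamping or adding epsilon to the final scalar is the safer pattern.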
In `@modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py`:
- Around line 380-396: The auto_map parsing in checkpoint_utils_hf.py
incorrectly assumes each model_config.auto_map value is a dotted string; update
the logic that builds module_files (and any usage of class_ref) to first
normalize each value by: if it's a list/tuple take the first element, if it
contains a repo qualifier split off the "repo_id--" prefix, then take the module
part before the first '.' and append ".py" (so "tokenization_my.py"); apply this
normalization where module_files is created and when iterating filenames so
lists/tuples and repo-qualified references are handled and the correct source
filenames are copied.
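The normalization the prompt asks for could look like this hypothetical helper (value shapes follow HF's auto_map conventions: a "module.Class" string, a "repo_id--module.Class" reference, or a list/tuple of such strings; the exact class names below are illustrative):

```python
def auto_map_module_file(value):
    """Resolve one auto_map entry to the .py filename it refers to."""
    if isinstance(value, (list, tuple)):   # e.g. tokenizer entries: [slow, fast]
        value = next(v for v in value if v is not None)
    if "--" in value:                      # strip a "repo_id--" qualifier
        value = value.split("--", 1)[1]
    module = value.split(".", 1)[0]        # module part before the class name
    return module + ".py"
```

Applying this before building module_files handles all three shapes instead of assuming a plain dotted string.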
In `@modelopt/torch/puzzletron/utils/data/dataloaders.py`:
- Around line 89-90: The DataLoader factory allows num_workers>0 while
ConstantLengthDataset.__iter__ does not shard via get_worker_info(), causing
duplicate samples; update the dataloaders (the function returning DataLoader in
modelopt/torch/puzzletron/utils/data/dataloaders.py) to either reject or
override num_workers>0 (e.g., force num_workers=0 or raise an error) until
ConstantLengthDataset.__iter__ is updated to use
torch.utils.data.get_worker_info() for worker sharding, and apply the same guard
to the other DataLoader creation call sites noted around the 114-118 region;
reference ConstantLengthDataset.__iter__ when adding the guard so future
implementers know to remove it after adding proper worker-aware iteration.
- Around line 98-99: The call to train_data.shuffle(seed=shuffle_seed,
keep_in_memory=True) fails for streaming (Iterable) datasets because
IterableDataset.shuffle() doesn't accept keep_in_memory; update the code that
checks shuffle_seed to detect streaming datasets (e.g., via whatever marker
load_streaming_fn sets or by checking hasattr(train_data, "__iter__") vs
__len__/isinstance of IterableDataset) and branch: for non-streaming datasets
call train_data.shuffle(seed=shuffle_seed, keep_in_memory=True) as before, and
for streaming/iterable datasets call train_data.shuffle(seed=shuffle_seed)
without keep_in_memory; ensure you modify the block that references shuffle_seed
and train_data.shuffle so runtime errors are avoided when load_streaming_fn()
returns a streaming dataset.
In `@tests/gpu/torch/puzzletron/test_bypass.py`:
- Line 213: The timeout passed to dist.setup uses timedelta(10) which means 10
days; change it to an explicit unit like timedelta(seconds=10) (or
timedelta(minutes=10) if intended) to avoid 10-day test hangs — locate the call
to dist.setup (symbol: dist.setup) in tests/gpu/torch/puzzletron/test_bypass.py
and the other listed files and replace timedelta(10) with timedelta(seconds=10)
(or the correct unit) in each occurrence.
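The unit mix-up is easy to verify: timedelta's first positional argument is days, so timedelta(10) really is ten days, not ten seconds:

```python
from datetime import timedelta

assert timedelta(10) == timedelta(days=10)
assert timedelta(10).total_seconds() == 864_000       # 10 days in seconds
assert timedelta(seconds=10).total_seconds() == 10.0  # the intended timeout
```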
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 371acd83-77b9-4396-8a82-eddd5b11dd40
📒 Files selected for processing (27)
examples/puzzletron/README.md
examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml
examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/bypass/defaults.yaml
examples/puzzletron/main.py
modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor.py
modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py
modelopt/torch/puzzletron/bypass_distillation/__init__.py
modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py
modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py
modelopt/torch/puzzletron/bypass_distillation/data_classes.py
modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py
modelopt/torch/puzzletron/bypass_distillation/training_loop.py
modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py
modelopt/torch/puzzletron/pruning/pruning_utils.py
modelopt/torch/puzzletron/puzzletron.py
modelopt/torch/puzzletron/sewing_kit/utils.py
modelopt/torch/puzzletron/tools/bypassed_training/child_init.py
modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py
modelopt/torch/puzzletron/utils/data/dataloaders.py
modelopt/torch/puzzletron/utils/data/dataset.py
modelopt/torch/puzzletron/utils/parsing.py
tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/bypass/test_bypass.yaml
tests/gpu/torch/puzzletron/test_bypass.py
tests/gpu/torch/puzzletron/test_puzzletron.py
tests/unit/torch/puzzletron/__init__.py
tests/unit/torch/puzzletron/test_bypass_losses.py
tests/unit/torch/puzzletron/test_bypass_utils.py
# If "latest" doesn't exist, look explicitly into directories with `*iter-*`
candidate_dirs = [d for d in run_parent_dir.glob("*iter-*") if d.is_dir()]

if not candidate_dirs:
    return None

def get_iter_num(dir_name):
    match = re.search(r"iter-(\d+)", dir_name.name)
    return int(match.group(1)) if match else 0

checkpoint_dirs = sorted(candidate_dirs, key=get_iter_num, reverse=True)
for latest_dir in checkpoint_dirs:
    if (latest_dir / "saving_completed").exists():
        return str(latest_dir)
Include step_num when picking the latest checkpoint.
This fallback only sorts on iter-(\d+). If a run writes multiple checkpoints inside the same iteration, resume can load an older step even though a newer checkpoint exists in the same run_parent_dir.
💡 Suggested fix
- def get_iter_num(dir_name):
- match = re.search(r"iter-(\d+)", dir_name.name)
- return int(match.group(1)) if match else 0
-
- checkpoint_dirs = sorted(candidate_dirs, key=get_iter_num, reverse=True)
+ def checkpoint_order(path: Path) -> tuple[int, int, float]:
+ match = re.search(r"iter-(\d+)(?:.*step-(\d+))?", path.name)
+ if not match:
+ return (0, 0, path.stat().st_mtime)
+ return (int(match.group(1)), int(match.group(2) or 0), path.stat().st_mtime)
+
+ checkpoint_dirs = sorted(candidate_dirs, key=checkpoint_order, reverse=True)
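A standalone version of the (iter, step) sort key suggested above, sketched on directory names alone (without the mtime tiebreaker, so it is easy to test):

```python
import re

def iter_step_key(name):
    """Sort key: (iteration, step), each defaulting to 0 when absent."""
    it = re.search(r"iter-(\d+)", name)
    st = re.search(r"step-(\d+)", name)
    return (int(it.group(1)) if it else 0, int(st.group(1)) if st else 0)

dirs = ["iter-2-step-100", "iter-10-step-5", "iter-10-step-50", "iter-3"]
latest_first = sorted(dirs, key=iter_step_key, reverse=True)
# "iter-10-step-50" now sorts ahead of "iter-10-step-5" and all earlier iters
```

Because the comparison is on integer tuples rather than strings, iter-10 correctly outranks iter-2, and two checkpoints in the same iteration are ordered by step.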
num_workers: int = 0,
) -> DataLoader:
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
set -euo pipefail
rg -n -C3 'class ConstantLengthDataset|def __iter__|get_worker_info|worker_info' \
modelopt/torch/puzzletron/utils/data/dataset.py \
modelopt/torch/puzzletron/utils/data/dataloaders.py

Repository: NVIDIA/Model-Optimizer
Length of output: 1247
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Get the full __iter__ method implementation
echo "=== ConstantLengthDataset.__iter__() implementation ==="
sed -n '100,150p' modelopt/torch/puzzletron/utils/data/dataset.py
echo ""
echo "=== create_train_dataloader function signature and first 30 lines ==="
sed -n '80,120p' modelopt/torch/puzzletron/utils/data/dataloaders.py
echo ""
echo "=== Check for get_worker_info usage in dataset.py ==="
grep -n "get_worker_info" modelopt/torch/puzzletron/utils/data/dataset.py || echo "No get_worker_info found"
echo ""
echo "=== Check for any guards on num_workers in dataloaders.py ==="
grep -n "num_workers" modelopt/torch/puzzletron/utils/data/dataloaders.py

Repository: NVIDIA/Model-Optimizer
Length of output: 4020
num_workers is unsafe until ConstantLengthDataset is worker-sharded.
ConstantLengthDataset.__iter__() in modelopt/torch/puzzletron/utils/data/dataset.py walks self.dataset directly and never uses get_worker_info(). Passing num_workers > 0 to the DataLoader will cause each worker to read the same iterable stream, duplicating training samples silently.
Suggested guard
def create_train_dataloader(
bos_rate: float = 1.0,
num_workers: int = 0,
) -> DataLoader:
"""Create an infinite training DataLoader over ConstantLengthDataset."""
+ if num_workers != 0:
+ raise ValueError(
+ "create_train_dataloader currently requires num_workers=0 because "
+ "ConstantLengthDataset is not sharded across DataLoader workers."
+ )
+
if isinstance(dataset_path, str):
dataset = load_dataset_fn(dataset_path, content_field, keep_in_memory)Also applies to: 114-118
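The worker sharding that ConstantLengthDataset.__iter__ is missing can be reduced to a pure helper (hypothetical sketch; the real fix would obtain worker_id and num_workers from torch.utils.data.get_worker_info() inside __iter__):

```python
def worker_shard(samples, worker_id, num_workers):
    """Yield only this worker's slice of an iterable stream, so that
    num_workers > 0 no longer feeds every worker the same samples."""
    for i, sample in enumerate(samples):
        if i % num_workers == worker_id:
            yield sample

# every sample goes to exactly one worker, with no duplication
shards = [list(worker_shard(range(10), w, 3)) for w in range(3)]
# shards == [[0, 3, 6, 9], [1, 4, 7], [2, 5, 8]]
```

With round-robin striding like this, the union of all worker shards reconstructs the original stream exactly once, which is the property the review says is currently violated.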
if shuffle_seed is not None:
    train_data = train_data.shuffle(seed=shuffle_seed, keep_in_memory=True)
🧩 Analysis chain
🌐 Web query:
Does Hugging Face datasets.IterableDataset.shuffle() accept a keep_in_memory keyword argument?
💡 Result:
No. datasets.IterableDataset.shuffle() only takes seed, generator, and buffer_size (no keep_in_memory kwarg), per both the API reference and the implementation. [1], [2]
Branch the shuffle call for streaming datasets.
IterableDataset.shuffle() does not accept keep_in_memory=True. When load_streaming_fn() is used (streaming datasets), line 99 will fail at runtime with an unexpected keyword argument error.
Suggested fix
train_data = dataset[dataset_name]
if shuffle_seed is not None:
- train_data = train_data.shuffle(seed=shuffle_seed, keep_in_memory=True)
+ if isinstance(train_data, datasets.IterableDataset):
+ train_data = train_data.shuffle(seed=shuffle_seed)
+ else:
+ train_data = train_data.shuffle(seed=shuffle_seed, keep_in_memory=True)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/puzzletron/utils/data/dataloaders.py` around lines 98 - 99,
The call to train_data.shuffle(seed=shuffle_seed, keep_in_memory=True) fails for
streaming (Iterable) datasets because IterableDataset.shuffle() doesn't accept
keep_in_memory; update the code that checks shuffle_seed to detect streaming
datasets (e.g., via whatever marker load_streaming_fn sets or by checking
hasattr(train_data, "__iter__") vs __len__/isinstance of IterableDataset) and
branch: for non-streaming datasets call train_data.shuffle(seed=shuffle_seed,
keep_in_memory=True) as before, and for streaming/iterable datasets call
train_data.shuffle(seed=shuffle_seed) without keep_in_memory; ensure you modify
the block that references shuffle_seed and train_data.shuffle so runtime errors
are avoided when load_streaming_fn() returns a streaming dataset.
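An alternative to the isinstance branch above, sketched here with stub classes standing in for map-style and streaming datasets, is to inspect the shuffle() signature and only pass keep_in_memory when it is supported (shuffle_dataset is a hypothetical helper, not part of the PR):

```python
import inspect


def shuffle_dataset(ds, seed):
    """Pass keep_in_memory only if this dataset's shuffle() accepts it."""
    params = inspect.signature(ds.shuffle).parameters
    if "keep_in_memory" in params:
        return ds.shuffle(seed=seed, keep_in_memory=True)  # map-style Dataset
    return ds.shuffle(seed=seed)  # streaming IterableDataset


class MapStyle:
    # Mimics datasets.Dataset.shuffle, which accepts keep_in_memory.
    def shuffle(self, seed, keep_in_memory=False):
        return ("map", seed, keep_in_memory)


class Streaming:
    # Mimics datasets.IterableDataset.shuffle: seed, generator, buffer_size only.
    def shuffle(self, seed, generator=None, buffer_size=1000):
        return ("stream", seed)


print(shuffle_dataset(MapStyle(), 7))   # prints ('map', 7, True)
print(shuffle_dataset(Streaming(), 7))  # prints ('stream', 7)
```

The isinstance check in the suggested fix is simpler when `datasets` is already imported; signature inspection just avoids coupling the helper to a concrete class.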
cjluo-nv
left a comment
Summary: Adds bypass distillation (blockwise local knowledge distillation) as an optional pipeline stage to puzzletron. Includes a full training loop, stitched model factory, checkpoint management, loss functions, data loader, configuration, and comprehensive tests. Also fixes bugs in child_init.py, dataset.py, and adds HF auto-download logic.
Issues Found:
- [Duplicated Code] normalized_mse_loss in sewing_kit/utils.py (diff lines 432-445) is an exact duplicate of the existing implementation in modelopt/torch/puzzletron/tools/kd_model.py:32-41. The new code should import and reuse the existing function rather than redefining it. The vectorwise_normalized_mse_loss and batched_normalized_mse_loss variants are new and fine, but they should build on the existing import.
- [Correctness / Security] training_loop.py:675: AutoTokenizer.from_pretrained uses a hardcoded trust_remote_code=True. The variable trust_remote_code is already computed from the descriptor at line 648. This should use trust_remote_code=trust_remote_code instead. (Flagged by pre-merge checks as well.)
- [Correctness / Security] bypass_checkpoint_utils.py:85,99: torch.load() calls lack weights_only=True. The codebase convention (e.g., checkpoint_utils.py:43,77) is to use weights_only=True for state-dict loading. These calls load state dicts and optimizer states respectively, which are pure tensor data and should use weights_only=True.
- [Correctness] training_loop.py: the except Exception as e block at the end of run_bypassed_training (around line 870) catches all exceptions and calls sys.exit(1) for non-SystemExit exceptions. This swallows the actual exception type and prevents proper test-framework error reporting. In GPU tests, a failing bypass run will produce SystemExit(1) instead of the real traceback. Consider re-raising, or at least logging before exit.
- [Correctness] stitched_model_factory.py:370-373: the lambda closures in the stitched-module creation loop (adapter=lambda v: InputArgs(target=v) and adapter=lambda v: InputArgs(input=v)) capture v correctly since they're arguments, but the loss target/input naming ("target" and "input") relies on block_loss_func accepting exactly these keyword arguments. If someone changes block_loss_func to e.g. batched_normalized_mse_loss, the keyword args don't match (batched_normalized_mse_loss takes input and target as positional args, not kwargs via InputArgs). This coupling is implicit and fragile; consider documenting the contract or adding a **kwargs adapter.
- [Correctness] bypass_checkpoint_utils.py:89: loaded_state_dict = {**stitched_module.state_dict(), **loaded_state_dict} merges the current state dict with the loaded one (loaded takes precedence). However, the current state dict is fetched before loading; if the model is on a different device, keys may contain tensors on the wrong device. The subsequent load_state_dict should handle this, but the intermediate merged dict is wasteful. Consider just using strict=False with load_state_dict directly.
- [Readability] stitched_model_factory.py: the bypass_factory_fn function is ~250 lines long with deeply nested logic. The student-model initialization block (lines 200-305) could be extracted into a helper like _initialize_student_model(...).
- [Readability] training_loop.py: the train() function is ~300 lines with deeply nested control flow for logging, validation, checkpoint saving, and time-based signals. Consider extracting the checkpoint-save and logging logic into separate functions.
- [Readability] stitched_model_factory.py:434-435: blank lines between the closing of the function and the backward-compatible aliases (gqa_factory_fn = bypass_factory_fn, moe_factory_fn = bypass_factory_fn). These aliases have no callers in this PR and no documentation. If they're for backward compat with existing configs, add a comment. If they're unused, remove them.
- [Tests] The GPU tests are thorough for the happy path but don't test checkpoint resume (loading from a previous run). The find_last_ckpt_for_resume + load_local_state path is complex and untested. At minimum, a test that runs bypass, then runs it again with find_last_ckpt_for_resume=True to verify resume works would increase confidence.
- [Tests] No unit test for _set_keys_to_learn, which has significant branching logic (subblock types, hybrid-model block_configs filtering, regex fallback). This function is critical for correctness.
- [Correctness] puzzletron_nas_plugin.py: the new auto-download logic in convert_puzzletron_model (lines 152-165) runs snapshot_download only on rank 0 inside if dist.is_master(), but then all ranks call dist.barrier(). If the download takes a long time, the barrier timeout (set in main.py as timedelta(10) = 10 days) should be fine, but the input_model_path variable is only updated on rank 0; other ranks never use it, since only rank 0 does the conversion. This is correct but subtle; a comment would help.
- [Correctness] bypass_utils.py:50: set_experiment_dir assigns a Path object to cfg.bypass.experiment_dir, but OmegaConf/DictConfig doesn't natively support Path objects. This works because OmegaConf stores it as-is with struct mode off, but it may cause serialization issues (e.g., json_dump in save_bypass_checkpoint). Consider converting to str.
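The two checkpoint-loading points above (weights_only=True for torch.load and strict=False instead of the dict merge) can be sketched together on a toy module; the Sequential model here is a stand-in for the actual stitched module:

```python
import os
import tempfile

import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))

# Save only the first layer's weights to simulate a partial checkpoint.
partial = {k: v for k, v in model.state_dict().items() if k.startswith("0.")}
with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as f:
    torch.save(partial, f.name)
    path = f.name

# weights_only=True restricts unpickling to tensor data, which is all a
# state dict or optimizer state needs (safer than full pickle loading).
loaded = torch.load(path, weights_only=True, map_location="cpu")

# strict=False applies the matching keys and reports the rest, replacing the
# wasteful {**model.state_dict(), **loaded} merge.
result = model.load_state_dict(loaded, strict=False)
print(sorted(result.missing_keys))  # prints ['1.bias', '1.weight']

os.remove(path)
```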
Suggestions:
- The _copy_auto_map_code_files addition in checkpoint_utils_hf.py is a good fix for trust_remote_code models. Consider adding a brief unit test, or at least a comment about which models require this (e.g., NemotronH).
- The format_stitched_losses NaN filtering is a nice quality-of-life improvement for hybrid models. The import math inside the function body should be moved to the module top level.
- The dataset.py chat_template fallback is correct and handles base models gracefully.
- The child_init.py fix (return item instead of return item_overrides when it is None) is a real bug fix; good catch.
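The Path-serialization concern raised for bypass_utils.py boils down to the following, shown with plain json as a stand-in for the checkpoint's json_dump (the config key is illustrative):

```python
import json
from pathlib import Path

experiment_dir = Path("/tmp/bypass_run")

# json cannot serialize a Path object, which is what a json_dump of a config
# holding a raw Path would hit at checkpoint-save time.
try:
    json.dumps({"experiment_dir": experiment_dir})
except TypeError as e:
    print("raises:", type(e).__name__)  # prints raises: TypeError

# Converting to str before storing it in the config avoids the issue.
print(json.dumps({"experiment_dir": str(experiment_dir)}))
```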
Overall Assessment: This is a well-structured, substantial feature addition. The core architecture (stitched model factory, per-block KD, pipeline integration) is sound. However, the hardcoded trust_remote_code=True security issue and the duplicated normalized_mse_loss need to be addressed before merge. The torch.load calls should also use weights_only=True per project convention.
- Fix realize_best_or_latest: add find_best_run_dir() and update realize_bypass_checkpoints() to honor the config field (was always using the latest checkpoint regardless of the setting)
- Improve experiment ID generation: replace hard-coded parsing logic with a config-driven spec table (_OVERRIDE_COMPONENT_SPECS) that handles FFN, MoE, GQA, and Mamba in a unified way; fix None values being included in IDs (e.g. bypass_ffn_None_heads_4 → bypass_kv4); new format: bypass_ffn256_kv4, bypass_experts4, bypass_mamba, etc.
- Simplify checkpoint resume: replace the wasteful state-dict dict-merge with load_state_dict(strict=False); add weights_only=True to all torch.load() calls
- Refactor stitched_model_factory: extract an _initialize_student_model() helper to reduce bypass_factory_fn from ~250 to ~100 lines; document the block_loss_func keyword-argument contract (input=, target=)
- Add find_best_run_dir to checkpoint_utils; add a NemotronH example to the _copy_auto_map_code_files docstring
- Tests: add a GPU test for checkpoint resume (find_last_ckpt_for_resume path); add unit tests for _set_keys_to_learn (all branches, including hybrid Mamba/GQA filtering) and set_experiment_id (11 cases)
- Fix ruff N806 in main.py (N → n); fix PT006 in test_bypass_utils.py; update the copyright year to 2026 on all new bypass files
Actionable comments posted: 8
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/puzzletron/tools/bypassed_training/child_init.py (1)
811-843: ⚠️ Potential issue | 🟠 Major
null overrides still crash for optional nested dataclasses. The new item_overrides is None branch is bypassed when previous_value is None and _is_dataclass_type(item_type) is true, so an override like ...: null still becomes _get_dataclass_type(item_type)(**item_overrides) and raises at runtime. This is easy to hit for optional sub-configs that default to None.
Suggested fix
- if previous_value is None and _is_dataclass_type(item_type):
-     new_value = _get_dataclass_type(item_type)(**item_overrides)
+ if item_overrides is None:
+     new_value = previous_value
+ elif previous_value is None and _is_dataclass_type(item_type):
+     assert isinstance(item_overrides, dict)
+     new_value = _get_dataclass_type(item_type)(**item_overrides)
  else:
      new_value = override(previous_value, item_overrides)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/puzzletron/tools/bypassed_training/child_init.py` around lines 811 - 843, The dataclass_override loop special-case instantiates a nested dataclass even when the provided override is None, causing a crash; modify the block in dataclass_override that handles "previous_value is None and _is_dataclass_type(item_type)" to first check if item_overrides is None and in that case set new_value = None (or call override(previous_value, item_overrides)), otherwise instantiate with _get_dataclass_type(item_type)(**item_overrides); keep the subsequent check_type(new_value, item_type) and existing symbols (override, dataclass_override, _is_dataclass_type, _get_dataclass_type, check_type) so optional nested dataclass overrides that are null no longer raise.
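A stripped-down illustration of the fixed branch order (override_field is a hypothetical simplification of dataclass_override; the real function also recurses and type-checks):

```python
from dataclasses import dataclass, is_dataclass


@dataclass
class SubCfg:
    lr: float = 0.1


def override_field(previous_value, item_overrides, item_type):
    # Check for a null override FIRST, before the nested-dataclass branch,
    # so `field: null` keeps the previous value instead of crashing on
    # item_type(**None).
    if item_overrides is None:
        return previous_value
    if previous_value is None and is_dataclass(item_type):
        return item_type(**item_overrides)
    return item_overrides  # simplified: the real code recurses via override()


print(override_field(None, None, SubCfg))         # prints None (no crash)
print(override_field(None, {"lr": 0.5}, SubCfg))  # prints SubCfg(lr=0.5)
```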
🧹 Nitpick comments (3)
modelopt/torch/puzzletron/tools/bypassed_training/child_init.py (1)
96-115: Reject overlapping outputs from multiple pruning mixins.
layer_out_state_dict.update(_layer_out) silently makes the final checkpoint depend on mixin order if two mixins emit the same state-dict key. Failing fast here is safer than letting one mixin overwrite the other.
Suggested guard
- layer_out_state_dict.update(_layer_out)
+ overlapping_keys = layer_out_state_dict.keys() & _layer_out.keys()
+ if overlapping_keys:
+     raise ValueError(
+         f"Pruning mixins produced overlapping keys for layer {layer_idx}: "
+         f"{sorted(overlapping_keys)}"
+     )
+ layer_out_state_dict.update(_layer_out)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/puzzletron/tools/bypassed_training/child_init.py` around lines 96 - 115, The loop over pruning mixins currently does layer_out_state_dict.update(_layer_out) which allows later mixins to silently overwrite keys from earlier ones; change this to detect overlapping keys and fail fast: for each _mixin when you get _layer_out from prune_single_layer, compute intersection = set(layer_out_state_dict.keys()) & set(_layer_out.keys()) and if intersection is non-empty raise a ValueError (or AssertionError) listing the conflicting keys and the mixin identity (use _mixin or its type) instead of updating; only call layer_out_state_dict.update(_layer_out) when intersection is empty to ensure deterministic, non-overlapping outputs from prune_single_layer across mixins.
modelopt/torch/puzzletron/tools/kd_model.py (1)
38-39: Add a zero/near-zero target regression test for this denominator change. This adjustment mainly changes behavior when target has a tiny norm, but tests/unit/torch/puzzletron/test_bypass_losses.py currently only covers random tensors. A focused zero-target case would keep this stabilization behavior from regressing unnoticed.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/puzzletron/tools/kd_model.py` around lines 38 - 39, Add a unit test in tests/unit/torch/puzzletron/test_bypass_losses.py (e.g., test_kd_loss_zero_and_near_zero_target) that imports the kd loss implementation from modelopt.torch.puzzletron.tools.kd_model, constructs both a zero target and a near-zero target tensor, calls the code path that computes loss using the expression containing F.mse_loss(input, target, reduction=reduction) / (F.mse_loss(target, torch.zeros_like(target), reduction=reduction) + epsilon), and asserts the loss is finite and behaves stably (no division-by-zero, not NaN/Inf) for both cases; use the same input tensor for both and check that adding the epsilon in the denominator prevents regressions.
modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py (1)
377-381: Consider more descriptive error handling for invalid block_loss_func. If cfg.model_factory.block_loss_func is not one of the three supported values, a KeyError is raised with just the invalid key name. A more descriptive error would help users identify the misconfiguration quickly.
Suggested improvement
+ _BLOCK_LOSS_FUNCS = {
+     "normalized_mse_loss": normalized_mse_loss,
+     "vectorwise_normalized_mse_loss": vectorwise_normalized_mse_loss,
+     "batched_normalized_mse_loss": batched_normalized_mse_loss,
+ }
+ loss_func_name = cfg.model_factory.block_loss_func
+ if loss_func_name not in _BLOCK_LOSS_FUNCS:
+     raise ValueError(
+         f"Unknown block_loss_func '{loss_func_name}'. "
+         f"Supported: {list(_BLOCK_LOSS_FUNCS.keys())}"
+     )
- block_loss_func = {
-     "normalized_mse_loss": normalized_mse_loss,
-     "vectorwise_normalized_mse_loss": vectorwise_normalized_mse_loss,
-     "batched_normalized_mse_loss": batched_normalized_mse_loss,
- }[cfg.model_factory.block_loss_func]
+ block_loss_func = _BLOCK_LOSS_FUNCS[loss_func_name]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py` around lines 377 - 381, The current lookup for block_loss_func in stitched_model_factory.py uses a direct dict index which raises an opaque KeyError when cfg.model_factory.block_loss_func is invalid; replace it with a guarded lookup: retrieve via dict.get or check membership first and raise a ValueError with a clear message that includes the invalid value and the allowed options (e.g., "normalized_mse_loss", "vectorwise_normalized_mse_loss", "batched_normalized_mse_loss"); update the code around the block_loss_func assignment (the dict and its use) so callers get a descriptive error instead of a raw KeyError.
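A minimal version of the regression test requested for kd_model.py, assuming the loss has the shape quoted in the prompt above (the function name and epsilon value here are illustrative, not the actual kd_model.py code):

```python
import torch
import torch.nn.functional as F

EPS = 1e-8  # assumed stabilizer; the actual epsilon in kd_model.py may differ


def normalized_mse_loss(input, target, reduction="mean"):
    # MSE normalized by the target's own energy, with an epsilon guard
    # in the denominator so a zero target cannot divide by zero.
    return F.mse_loss(input, target, reduction=reduction) / (
        F.mse_loss(target, torch.zeros_like(target), reduction=reduction) + EPS
    )


def test_zero_and_near_zero_target():
    x = torch.randn(4, 8)
    for target in (torch.zeros(4, 8), torch.full((4, 8), 1e-20)):
        loss = normalized_mse_loss(x, target)
        assert torch.isfinite(loss), "loss must stay finite for tiny targets"


test_zero_and_near_zero_target()
```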
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py`:
- Around line 113-125: The checkpoint load/save is missing GradScaler state so
when use_grad_scaling=True resumed runs lose scaler state; update the save and
load paths around the StitchedModuleDescriptor handling to persist
grad_scaler.state_dict() (e.g., save to
stitched/{stitched_module_name}.grad_scaler.pth) when grad_scaler is not None
and on load (in the blocks that currently load optimizer state and in the
similar 165-171 block) call grad_scaler.load_state_dict(...) after constructing
or retrieving the module’s grad_scaler, using map_location=device, and guard
with the use_grad_scaling flag so scaler state is restored only when applicable.
In `@modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py`:
- Around line 597-600: Log currently prints an empty block name because
submodule_name is never reassigned; update the mprint call in
stitched_model_factory.py (the one that prints Block {submodule_name}: ...) to
use a valid identifier such as student_stitched_module_name (or module_name) so
the block name is meaningful; locate the mprint invocation and replace
submodule_name with student_stitched_module_name (ensuring
student_stitched_module_name is in scope) and keep the rest of the
message/parameter counting intact.
- Around line 417-428: The code assumes owned_block_indexes is non-empty before
calling min()/max(), which will raise ValueError if a rank owns no blocks; in
the block around min_owned_index/max_owned_index in stitched_model_factory.py,
first check if not owned_block_indexes and handle it defensively (e.g., set
prev_rank and next_rank to None or raise a clear, explanatory error) instead of
calling min()/max(); update the logic that computes prev_rank and next_rank
using model_blocks_process_ownership and all_block_indices to only run when
owned_block_indexes is non-empty so misconfiguration yields a clear message or
safe defaults.
In `@modelopt/torch/puzzletron/bypass_distillation/training_loop.py`:
- Around line 837-838: The code reads source_datasets_to_discard from cfg.bypass
root but the new config nests it under bypass.data; update the dataloader calls
that set source_datasets_to_discard (and any similar occurrences) to read from
cfg.bypass.data.get("source_datasets_to_discard", tuple()) instead of
cfg.bypass.get(...), leaving bos_rate as cfg.bypass.data.bos_rate; search for
the occurrences that set source_datasets_to_discard (the two places mentioned
around the calls that also use bos_rate) and replace them to use cfg.bypass.data
so the discard list becomes configurable.
- Around line 252-253: The parameter skip_first_batches is never applied: after
creating the batch iterator from the ConstantLengthDataset/dataloader you must
advance that iterator by skip_first_batches before entering the training loop
(e.g., consume the iterator with next(...) in a short loop or use
itertools.islice to drop the first N items); update the code paths where
skip_first_batches is accepted (the occurrences around skip_first_batches in
training_loop.py and the second occurrence at lines ~329-330) to consume the
iterator accordingly so resumed runs do not replay from batch 0.
- Around line 349-350: The loop exit condition uses a 1-based counter and
currently uses >=, causing it to stop one step too early; update the check in
the training loop that references cfg.bypass.step_num and
cfg.bypass.training.max_steps so it breaks only once step_num has passed the
budget (use > instead of >=) so the final scheduled step runs.
- Around line 103-107: The AutoConfig.from_pretrained call inside
run_bypassed_training bypasses the earlier trust_remote_code decision; update
the AutoConfig.from_pretrained invocation in training_loop.py (the block that
imports HFAutoConfig and sets teacher_hf_cfg/teacher_intermediate_size) to pass
the same trust_remote_code flag you query earlier (the
descriptor/trust_remote_code value used by run_bypassed_training) so
remote-code-required models load consistently and safely. Locate the
AutoConfig.from_pretrained usage and add the trust_remote_code argument (sourced
from the existing hydra_cfg/descriptor/trust_remote_code variable) when calling
from_pretrained, ensuring the call mirrors the trusted-remote-code decision made
elsewhere.
In `@tests/gpu/torch/puzzletron/test_puzzletron.py`:
- Around line 209-215: The test incorrectly computes per-rank FFN counts from
hidden layer count (total_layers = max(2, size)); instead compute the actual
number of prunable FFN blocks (e.g., scan the model's layer names or modules to
count FFN/prunable blocks rather than using hidden-layer count) into
total_ffn_blocks, then compute layers_this_rank = total_ffn_blocks // size + (1
if rank < total_ffn_blocks % size else 0) and assert len(layer_names) ==
layers_this_rank (allowing 0 for ranks that only own Mamba blocks); update the
variables total_layers/layers_this_rank and reference layer_names when making
this change.
✅ Files skipped from review due to trivial changes (67)
- examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml
- tests/unit/torch/puzzletron/init.py
- examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/bypass/defaults.yaml
- examples/puzzletron/configs/nemotron-nano-12b-v2/validate_solutions_defaults.yaml
- modelopt/torch/puzzletron/dataset/prepare_dataset.py
- examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_solutions_defaults.yaml
- examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_solutions_defaults.yaml
- examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/pruning_defaults.yaml
- examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/pruning_defaults.yaml
- examples/puzzletron/configs/gptoss-20b_remove_experts_memory/bypass/defaults.yaml
- examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_solutions_defaults.yaml
- examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/pruning_defaults.yaml
- examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_model_defaults.yaml
- modelopt/torch/utils/plugins/transformers_dataset.py
- examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_solutions_defaults.yaml
- examples/puzzletron/configs/nemotron-nano-12b-v2/bypass/defaults.yaml
- examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/Mistral-Small-24B.yaml
- examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/qwen2_5_7b_instruct.yaml
- examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_model_defaults.yaml
- examples/puzzletron/configs/realize_model/validate_solutions_defaults.yaml
- examples/puzzletron/configs/qwen3-8b_pruneffn_memory/bypass/defaults.yaml
- tests/_test_utils/torch/puzzletron/utils.py
- examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_model_defaults.yaml
- examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/hidden_dim_pruning.yaml
- examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/bypass/defaults.yaml
- examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_solutions_defaults.yaml
- examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_solutions_defaults.yaml
- examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/Llama-3_2-3B.yaml
- examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_model_defaults.yaml
- examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b.yaml
- examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_solutions_defaults.yaml
- examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/pruning_defaults.yaml
- examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/pruning_defaults.yaml
- examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_model_defaults.yaml
- examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_solutions_defaults.yaml
- examples/puzzletron/configs/qwen3-8b_pruneffn_memory/qwen3_8b.yaml
- examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml
- tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/pruning/expert_pruning.yaml
- examples/puzzletron/configs/nemotron-nano-12b-v2/validate_model_defaults.yaml
- examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/attn_pruning.yaml
- examples/puzzletron/configs/validate_solutions_defaults.yaml
- examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/bypass/defaults.yaml
- examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_model_defaults.yaml
- examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/pruning_defaults.yaml
- examples/puzzletron/configs/nemotron-nano-12b-v2/nemotron_nano_12b_v2.yaml
- examples/puzzletron/configs/scoring/validate_solutions_defaults.yaml
- examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/bypass/defaults.yaml
- examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_model_defaults.yaml
- examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_solutions_defaults.yaml
- examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_model_defaults.yaml
- examples/puzzletron/README.md
- modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py
- examples/puzzletron/configs/validate_model_defaults.yaml
- examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/pruning_defaults.yaml
- examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_solutions_defaults.yaml
- modelopt/torch/puzzletron/bypass_distillation/data_classes.py
- modelopt/torch/puzzletron/bypass_distillation/__init__.py
- examples/puzzletron/configs/bypass/defaults.yaml
- examples/puzzletron/BYPASS.md
- examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/pruning_defaults.yaml
- examples/puzzletron/configs/nemotron-nano-12b-v2/validate_solutions_defaults.yaml
- examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_solutions_defaults.yaml
- tests/unit/torch/puzzletron/test_bypass_losses.py
- examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/pruning_defaults.yaml
- examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_model_defaults.yaml
- examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/pruning_defaults.yaml
🚧 Files skipped from review as they are similar to previous changes (6)
- examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml
- modelopt/torch/puzzletron/utils/parsing.py
- modelopt/torch/puzzletron/pruning/pruning_utils.py
- modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py
- modelopt/torch/puzzletron/sewing_kit/utils.py
- modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py
    if optimizer is not None:
        optimizer_state_path = (
            load_dir / "stitched" / f"{stitched_module_name}.optimizer_state.pth"
        )
        if verbose:
            mprint(
                f"Loading optimizer state for module {stitched_module_name} from {optimizer_state_path}"
            )
        loaded_optimizer_state = torch.load(
            optimizer_state_path, map_location=device, weights_only=True
        )
        optimizer.load_state_dict(loaded_optimizer_state)
        del loaded_optimizer_state
Persist GradScaler state as part of the bypass checkpoint.
StitchedModuleDescriptor includes grad_scaler, but the checkpoint only saves/restores model and optimizer state. With use_grad_scaling=True, a resumed run restarts from a fresh scale factor instead of the checkpointed training state.
💡 Suggested fix
if optimizer is not None:
optimizer_state_path = (
load_dir / "stitched" / f"{stitched_module_name}.optimizer_state.pth"
)
if verbose:
mprint(
f"Loading optimizer state for module {stitched_module_name} from {optimizer_state_path}"
)
loaded_optimizer_state = torch.load(
optimizer_state_path, map_location=device, weights_only=True
)
optimizer.load_state_dict(loaded_optimizer_state)
del loaded_optimizer_state
+
+ grad_scaler = stitched_module_descriptor.grad_scaler
+ if grad_scaler is not None:
+ scaler_state_path = load_dir / "stitched" / f"{stitched_module_name}.grad_scaler_state.pth"
+ loaded_scaler_state = torch.load(
+ scaler_state_path, map_location=device, weights_only=True
+ )
+ grad_scaler.load_state_dict(loaded_scaler_state)
+ del loaded_scaler_state
...
if optimizer is not None:
optimizer_state_path = save_dir / f"{stitched_module_name}.optimizer_state.pth"
if verbose:
mprint(
f"Saving optimizer state for module {stitched_module_name} to {optimizer_state_path}"
)
_save_local_file(optimizer.state_dict(), optimizer_state_path, overwrite=overwrite)
+
+ grad_scaler = stitched_module_descriptor.grad_scaler
+ if grad_scaler is not None:
+ scaler_state_path = save_dir / f"{stitched_module_name}.grad_scaler_state.pth"
+        _save_local_file(grad_scaler.state_dict(), scaler_state_path, overwrite=overwrite)

Also applies to: 165-171
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py`
around lines 113 - 125, The checkpoint load/save is missing GradScaler state so
when use_grad_scaling=True resumed runs lose scaler state; update the save and
load paths around the StitchedModuleDescriptor handling to persist
grad_scaler.state_dict() (e.g., save to
stitched/{stitched_module_name}.grad_scaler.pth) when grad_scaler is not None
and on load (in the blocks that currently load optimizer state and in the
similar 165-171 block) call grad_scaler.load_state_dict(...) after constructing
or retrieving the module’s grad_scaler, using map_location=device, and guard
with the use_grad_scaling flag so scaler state is restored only when applicable.
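The resume bug described above (one component left out of the checkpoint silently resets to its default on restart) can be reproduced with plain dictionaries. This is an illustrative sketch, not the puzzletron API; the function names are hypothetical, and 65536 is only assumed here as a typical initial loss-scale default:

```python
# Sketch: only model/optimizer state is checkpointed, so the grad-scaler
# "scale" silently falls back to its default when a run is resumed.
def save_checkpoint(model_state, optim_state, scaler_state=None):
    ckpt = {"model": model_state, "optimizer": optim_state}
    if scaler_state is not None:  # the fix: persist scaler state too
        ckpt["grad_scaler"] = scaler_state
    return ckpt

def resume_scale(ckpt, default_scale=65536.0):
    # Without a saved scaler entry, training restarts at the default scale.
    scaler = ckpt.get("grad_scaler", {"scale": default_scale})
    return scaler["scale"]

ckpt_buggy = save_checkpoint({"w": 1.0}, {"lr": 1e-4})
ckpt_fixed = save_checkpoint({"w": 1.0}, {"lr": 1e-4}, {"scale": 1024.0})

print(resume_scale(ckpt_buggy))  # 65536.0 -- warmed-up scale lost
print(resume_scale(ckpt_fixed))  # 1024.0  -- scale restored
```

The same round-trip applies to the real `GradScaler`, which exposes `state_dict()`/`load_state_dict()` exactly like the optimizer.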
    min_owned_index = min(owned_block_indexes)
    max_owned_index = max(owned_block_indexes)
    prev_rank: Optional[int] = (
        None
        if min_owned_index == min(all_block_indices)
        else model_blocks_process_ownership[min_owned_index - 1]
    )
    next_rank: Optional[int] = (
        None
        if max_owned_index == max(all_block_indices)
        else model_blocks_process_ownership[max_owned_index + 1]
    )
Potential ValueError if a rank owns no blocks.
min(owned_block_indexes) and max(owned_block_indexes) will raise ValueError if owned_block_indexes is empty. While the current design likely ensures every rank owns at least one block, defensive handling would prevent cryptic errors during misconfiguration.
Suggested defensive check
+ if not owned_block_indexes:
+ raise ValueError(
+ f"Rank {dist.rank()} owns no blocks. Check model_blocks_process_ownership mapping."
+ )
+
min_owned_index = min(owned_block_indexes)
  max_owned_index = max(owned_block_indexes)

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
    if not owned_block_indexes:
        raise ValueError(
            f"Rank {dist.rank()} owns no blocks. Check model_blocks_process_ownership mapping."
        )
    min_owned_index = min(owned_block_indexes)
    max_owned_index = max(owned_block_indexes)
    prev_rank: Optional[int] = (
        None
        if min_owned_index == min(all_block_indices)
        else model_blocks_process_ownership[min_owned_index - 1]
    )
    next_rank: Optional[int] = (
        None
        if max_owned_index == max(all_block_indices)
        else model_blocks_process_ownership[max_owned_index + 1]
    )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py`
around lines 417 - 428, The code assumes owned_block_indexes is non-empty before
calling min()/max(), which will raise ValueError if a rank owns no blocks; in
the block around min_owned_index/max_owned_index in stitched_model_factory.py,
first check if not owned_block_indexes and handle it defensively (e.g., set
prev_rank and next_rank to None or raise a clear, explanatory error) instead of
calling min()/max(); update the logic that computes prev_rank and next_rank
using model_blocks_process_ownership and all_block_indices to only run when
owned_block_indexes is non-empty so misconfiguration yields a clear message or
safe defaults.
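The hazard being guarded against is that Python's `min()`/`max()` raise a bare `ValueError: min() arg is an empty sequence` on an empty input, which says nothing about rank ownership. A standalone sketch of the guarded computation (the wrapper function is hypothetical; the variable names mirror the snippet above):

```python
from typing import Optional

def neighbor_ranks(owned, ownership, all_indices, rank):
    # Guard before min()/max(): an empty "owned" list would otherwise
    # surface as a cryptic "min() arg is an empty sequence" error.
    if not owned:
        raise ValueError(f"Rank {rank} owns no blocks; check the ownership map.")
    lo, hi = min(owned), max(owned)
    prev_rank: Optional[int] = None if lo == min(all_indices) else ownership[lo - 1]
    next_rank: Optional[int] = None if hi == max(all_indices) else ownership[hi + 1]
    return prev_rank, next_rank

# blocks 0-1 on rank 0, blocks 2-3 on rank 1
ownership = {0: 0, 1: 0, 2: 1, 3: 1}
print(neighbor_ranks([2, 3], ownership, [0, 1, 2, 3], rank=1))  # (0, None)
```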
    mprint(
        f"Block {submodule_name}: {len(trainable_params)} trainable parameter tensors "
        f"({sum(p.numel() for p in trainable_params.values()):,} params)"
    )
Log message will always show empty block name.
submodule_name is initialized to "" at line 449 and never reassigned within the loop. The log message "Block : ..." will always display an empty block name. Consider using student_stitched_module_name (e.g., block_0) or module_name for clarity.
Suggested fix
mprint(
- f"Block {submodule_name}: {len(trainable_params)} trainable parameter tensors "
+ f"Block {student_stitched_module_name}: {len(trainable_params)} trainable parameter tensors "
f"({sum(p.numel() for p in trainable_params.values()):,} params)"
    )

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
    mprint(
        f"Block {student_stitched_module_name}: {len(trainable_params)} trainable parameter tensors "
        f"({sum(p.numel() for p in trainable_params.values()):,} params)"
    )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py`
around lines 597 - 600, Log currently prints an empty block name because
submodule_name is never reassigned; update the mprint call in
stitched_model_factory.py (the one that prints Block {submodule_name}: ...) to
use a valid identifier such as student_stitched_module_name (or module_name) so
the block name is meaningful; locate the mprint invocation and replace
submodule_name with student_stitched_module_name (ensuring
student_stitched_module_name is in scope) and keep the rest of the
message/parameter counting intact.
    if do_ffn or do_attn or do_blk:
        from transformers import AutoConfig as HFAutoConfig

        teacher_hf_cfg = HFAutoConfig.from_pretrained(str(hydra_cfg.teacher_dir))
        teacher_intermediate_size = getattr(teacher_hf_cfg, "intermediate_size", None)
🧩 Analysis chain
🏁 Script executed:

    cat -n modelopt/torch/puzzletron/bypass_distillation/training_loop.py | sed -n '95,120p'

Repository: NVIDIA/Model-Optimizer (output length: 1567)

🏁 Script executed:

    rg "requires_trust_remote_code|trust_remote_code" modelopt/torch/puzzletron/bypass_distillation/training_loop.py -B 2 -A 2

Repository: NVIDIA/Model-Optimizer (output length: 614)

🏁 Script executed:

    rg "ModelDescriptorFactory" modelopt/torch/puzzletron/bypass_distillation/training_loop.py -B 2 -A 2

Repository: NVIDIA/Model-Optimizer (output length: 539)

🏁 Script executed:

    cat -n modelopt/torch/puzzletron/bypass_distillation/training_loop.py | sed -n '1,110p'

Repository: NVIDIA/Model-Optimizer (output length: 5786)

🏁 Script executed:

    rg "def requires_trust_remote_code" modelopt/ -A 3

Repository: NVIDIA/Model-Optimizer (output length: 1441)

🏁 Script executed:

    cat -n modelopt/torch/puzzletron/bypass_distillation/training_loop.py | sed -n '73,250p'

Repository: NVIDIA/Model-Optimizer (output length: 9638)

🏁 Script executed:

    rg "def launch_bypass_distillation" modelopt/torch/puzzletron/bypass_distillation/training_loop.py -A 50

Repository: NVIDIA/Model-Optimizer (output length: 2748)

🏁 Script executed:

    rg "hydra_cfg.descriptor|cfg.descriptor" modelopt/torch/puzzletron/bypass_distillation/training_loop.py -B 2 -A 2

Repository: NVIDIA/Model-Optimizer (output length: 344)

🏁 Script executed:

    cat -n modelopt/torch/puzzletron/bypass_distillation/training_loop.py | sed -n '240,300p'

Repository: NVIDIA/Model-Optimizer (output length: 2827)

🏁 Script executed:

    rg "descriptor\s*=" modelopt/torch/puzzletron/bypass_distillation/training_loop.py | head -20

Repository: NVIDIA/Model-Optimizer (output length: 558)

🏁 Script executed:

    cat -n modelopt/torch/puzzletron/bypass_distillation/training_loop.py | sed -n '290,360p'

Repository: NVIDIA/Model-Optimizer (output length: 3106)

🏁 Script executed:

    rg "def run_bypassed_training" modelopt/torch/puzzletron/bypass_distillation/training_loop.py -A 30

Repository: NVIDIA/Model-Optimizer (output length: 1259)
Thread trust_remote_code through the auto-config probe.
run_bypassed_training() queries the descriptor for trust_remote_code, but this auto-config path at lines 103–107 bypasses that and calls AutoConfig.from_pretrained() with default (unsafe) behavior. Models requiring remote code execution will fail inconsistently depending on which path loads them.
💡 Suggested fix
if do_ffn or do_attn or do_blk:
from transformers import AutoConfig as HFAutoConfig
- teacher_hf_cfg = HFAutoConfig.from_pretrained(str(hydra_cfg.teacher_dir))
+ descriptor = ModelDescriptorFactory.get(hydra_cfg.descriptor)
+ trust_remote_code = descriptor.requires_trust_remote_code()
+ teacher_hf_cfg = HFAutoConfig.from_pretrained(
+ str(hydra_cfg.teacher_dir),
+ trust_remote_code=trust_remote_code,
+    )

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/puzzletron/bypass_distillation/training_loop.py` around lines
103 - 107, The AutoConfig.from_pretrained call inside run_bypassed_training
bypasses the earlier trust_remote_code decision; update the
AutoConfig.from_pretrained invocation in training_loop.py (the block that
imports HFAutoConfig and sets teacher_hf_cfg/teacher_intermediate_size) to pass
the same trust_remote_code flag you query earlier (the
descriptor/trust_remote_code value used by run_bypassed_training) so
remote-code-required models load consistently and safely. Locate the
AutoConfig.from_pretrained usage and add the trust_remote_code argument (sourced
from the existing hydra_cfg/descriptor/trust_remote_code variable) when calling
from_pretrained, ensuring the call mirrors the trusted-remote-code decision made
elsewhere.
    skip_first_batches: int = 0,
    tokenizer: Optional[PreTrainedTokenizerBase] = None,
skip_first_batches is currently a no-op.
The iterator is created and consumed immediately, but never advanced by skip_first_batches. On resume that replays the training stream from batch 0, because ConstantLengthDataset does not persist iterator position.
💡 Suggested fix
train_iterator = iter(train_dataloader)
+ if dist.is_master() and skip_first_batches:
+ for _ in range(skip_first_batches):
+ next(train_iterator)
    mprint("Waiting for everyone before training starts")

Also applies to: 329-330
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/puzzletron/bypass_distillation/training_loop.py` around lines
252 - 253, The parameter skip_first_batches is never applied: after creating the
batch iterator from the ConstantLengthDataset/dataloader you must advance that
iterator by skip_first_batches before entering the training loop (e.g., consume
the iterator with next(...) in a short loop or use itertools.islice to drop the
first N items); update the code paths where skip_first_batches is accepted (the
occurrences around skip_first_batches in training_loop.py and the second
occurrence at lines ~329-330) to consume the iterator accordingly so resumed
runs do not replay from batch 0.
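Advancing an iterator by N items is exactly the `consume` recipe from the itertools documentation: `next(islice(it, n, n), None)` drains the first `n` items without materializing them. A minimal sketch of the fix (the wrapper name and the `range` stand-in for the dataloader are illustrative):

```python
from itertools import islice

def make_train_iterator(dataloader, skip_first_batches=0):
    it = iter(dataloader)
    if skip_first_batches:
        # Consume and discard the first N batches so a resumed run
        # continues from where the stream left off instead of batch 0.
        next(islice(it, skip_first_batches, skip_first_batches), None)
    return it

# range(10) stands in for a dataloader yielding batches 0..9
it = make_train_iterator(range(10), skip_first_batches=3)
print(next(it))  # 3
```

`islice` never buffers the skipped items, so this stays O(1) in memory even for large skip counts.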
    if cfg.bypass.step_num >= cfg.bypass.training.max_steps:
    if (
Stop after the last scheduled step, not before it.
With the current 1-based step_num, max_steps=1 exits before the first optimizer step and max_steps=2 only executes one step. This should break once step_num has moved past the budget, not when it is equal to it.
💡 Suggested fix
- if cfg.bypass.step_num >= cfg.bypass.training.max_steps:
+ if cfg.bypass.step_num > cfg.bypass.training.max_steps:

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
    if cfg.bypass.step_num > cfg.bypass.training.max_steps:
    if (
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/puzzletron/bypass_distillation/training_loop.py` around lines
349 - 350, The loop exit condition uses a 1-based counter and currently uses >=,
causing it to stop one step too early; update the check in the training loop
that references cfg.bypass.step_num and cfg.bypass.training.max_steps so it
breaks only once step_num has passed the budget (use > instead of >=) so the
final scheduled step runs.
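The off-by-one is easy to demonstrate with a toy loop. This sketch is hypothetical (the real loop has more state), but it reproduces the counting exactly: a 1-based counter checked at the top of the loop with `>=` executes one step fewer than budgeted, while `>` executes exactly `max_steps`:

```python
def run_steps(max_steps, strict_ge=True):
    # 1-based step counter, checked at the *top* of the loop,
    # mirroring the cfg.bypass.step_num pattern in the review.
    step_num, executed = 1, 0
    while True:
        done = step_num >= max_steps if strict_ge else step_num > max_steps
        if done:
            break
        executed += 1  # the optimizer step would happen here
        step_num += 1
    return executed

print(run_steps(1, strict_ge=True))   # 0 -- exits before the first step
print(run_steps(2, strict_ge=True))   # 1 -- one step short of the budget
print(run_steps(2, strict_ge=False))  # 2 -- full budget executed
```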
    source_datasets_to_discard=cfg.bypass.get("source_datasets_to_discard", tuple()),
    bos_rate=cfg.bypass.data.bos_rate,
Read source_datasets_to_discard from bypass.data.
The new bypass config nests this field under bypass.data, but both dataloader calls read from the bypass root. As written, the discard list is effectively impossible to configure.
💡 Suggested fix
- source_datasets_to_discard=cfg.bypass.get("source_datasets_to_discard", tuple()),
+ source_datasets_to_discard=cfg.bypass.data.get(
+ "source_datasets_to_discard", tuple()
+ ),
...
- source_datasets_to_discard=cfg.bypass.get("source_datasets_to_discard", tuple()),
+ source_datasets_to_discard=cfg.bypass.data.get(
+ "source_datasets_to_discard", tuple()
+ )

Also applies to: 858-859
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/puzzletron/bypass_distillation/training_loop.py` around lines
837 - 838, The code reads source_datasets_to_discard from cfg.bypass root but
the new config nests it under bypass.data; update the dataloader calls that set
source_datasets_to_discard (and any similar occurrences) to read from
cfg.bypass.data.get("source_datasets_to_discard", tuple()) instead of
cfg.bypass.get(...), leaving bos_rate as cfg.bypass.data.bos_rate; search for
the occurrences that set source_datasets_to_discard (the two places mentioned
around the calls that also use bos_rate) and replace them to use cfg.bypass.data
so the discard list becomes configurable.
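The bug class here is a lookup at the wrong nesting level: `.get()` with a default never fails, it just silently returns the default forever. A plain-dict sketch (the config shape mimics the hydra layout; the dataset name is made up):

```python
# Minimal sketch of why the discard list is unreachable: the value lives
# under bypass.data, but the lookup reads from the bypass root.
cfg = {
    "bypass": {
        "data": {
            "source_datasets_to_discard": ("noisy_web",),
            "bos_rate": 0.5,
        },
    },
}

wrong = cfg["bypass"].get("source_datasets_to_discard", tuple())          # root lookup
right = cfg["bypass"]["data"].get("source_datasets_to_discard", tuple())  # nested lookup

print(wrong)  # () -- the default always wins, no error is ever raised
print(right)  # ('noisy_web',)
```

Because the default masks the miss, this kind of regression only shows up as "the setting has no effect", which is why wrong-level `.get()` calls are worth flagging in review.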
    # The test model has num_hidden_layers = max(2, size), so every rank owns at least
    # one layer. Compute the actual expected count for *this* rank.
    total_layers = max(2, size)
    layers_this_rank = total_layers // size + (1 if rank < total_layers % size else 0)
    assert len(layer_names) == layers_this_rank, (
        f"Expected {layers_this_rank} FFN layers on rank {rank}/{size}, got {len(layer_names)}"
    )
The per-rank FFN count is still wrong for hybrid models.
total_layers = max(2, size) counts hidden layers, not prunable FFN blocks. This file already documents nvidia/NVIDIA-Nemotron-Nano-12B-v2 as having only one FFN layer, so a rank that owns only Mamba blocks can legitimately have len(layer_names) == 0.
💡 Suggested fix
- total_layers = max(2, size)
- layers_this_rank = total_layers // size + (1 if rank < total_layers % size else 0)
- assert len(layer_names) == layers_this_rank, (
- f"Expected {layers_this_rank} FFN layers on rank {rank}/{size}, got {len(layer_names)}"
- )
+ total_layers = max(2, size)
+ if len(expected) == total_layers:
+ layers_this_rank = total_layers // size + (1 if rank < total_layers % size else 0)
+ assert len(layer_names) == layers_this_rank, (
+ f"Expected {layers_this_rank} FFN layers on rank {rank}/{size}, got {len(layer_names)}"
+ )🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/gpu/torch/puzzletron/test_puzzletron.py` around lines 209 - 215, The
test incorrectly computes per-rank FFN counts from hidden layer count
(total_layers = max(2, size)); instead compute the actual number of prunable FFN
blocks (e.g., scan the model's layer names or modules to count FFN/prunable
blocks rather than using hidden-layer count) into total_ffn_blocks, then compute
layers_this_rank = total_ffn_blocks // size + (1 if rank < total_ffn_blocks %
size else 0) and assert len(layer_names) == layers_this_rank (allowing 0 for
ranks that only own Mamba blocks); update the variables
total_layers/layers_this_rank and reference layer_names when making this change.
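The per-rank count formula under discussion is the standard remainder split: each rank gets `total // size` blocks, and the first `total % size` ranks get one extra. A self-contained sketch (the helper name is made up) that also shows why a rank can legitimately own zero FFN blocks when there are fewer FFN blocks than ranks:

```python
def blocks_for_rank(total_blocks, world_size, rank):
    # Remainder split: the first (total % size) ranks get one extra block.
    base, extra = divmod(total_blocks, world_size)
    return base + (1 if rank < extra else 0)

# 5 blocks over 2 ranks -> 3 + 2
print([blocks_for_rank(5, 2, r) for r in range(2)])  # [3, 2]
# 1 FFN block over 4 ranks -> only rank 0 owns it; the rest own zero,
# which is exactly the hybrid-model case the review comment describes.
print([blocks_for_rank(1, 4, r) for r in range(4)])  # [1, 0, 0, 0]
```

Summing the per-rank counts always recovers the total, which is a cheap invariant to assert in tests instead of assuming every rank owns at least one block.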
- Extract common setup preamble (dist.setup, register_hydra_resolvers, hydra config load, _total_steps) into _setup() helper in main.py to eliminate duplication between run_full_puzzletron and run_mip_only
- Rename uppercase N → n in main.py and puzzletron_nas_plugin.py
- Remove unused gqa_factory_fn and moe_factory_fn aliases from stitched_model_factory.py
- Improve BYPASS.md: clarify when to run bypass (KV head reduction, no_op blocks, extreme FFN/MoE compression); fix coupled BLD cost description (N×M runs vs N+M, not harder to optimise)
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/puzzletron/main.py (1)

102-135: ⚠️ Potential issue | 🟠 Major: Ensure distributed cleanup runs on failure paths.

If convert/search/sweep/MIP raises, `dist.cleanup()` is skipped. In multi-GPU flows this can leave process groups hanging. Wrap execution in `try/finally` in both runners.

Proposed fix

     def run_full_puzzletron(hydra_config_path: str):
    @@
         hydra_cfg, hydra_config_dir, hydra_config_name, n = _setup(hydra_config_path)
         mprint(f"Puzzletron Progress 1/{n}: starting puzzletron pipeline")
    -
    -    # Convert model (convert from HF to DeciLM, score pruning activations,
    -    # prune the model and save pruned checkpoints)
    -    input_model = PuzzletronModel()
    -    converted_model = mtn.convert(
    -        input_model,
    -        mode=[
    -            (
    -                "puzzletron",
    -                {
    -                    "puzzle_dir": str(hydra_cfg.puzzle_dir),
    -                    "input_model_path": hydra_cfg.input_hf_model_path,
    -                    "hydra_config_dir": hydra_config_dir,
    -                    "hydra_config_name": hydra_config_name,
    -                    "dataset_path": str(hydra_cfg.dataset_path),
    -                },
    -            )
    -        ],
    -    )
    -
    -    # Run NAS search (build replacement library and compute stats,
    -    # compute one block scores, run MIP and realize models)
    -    mtn.search(
    -        converted_model,
    -        constraints={},  # this is not used as the search space is defined in the hydra config
    -        dummy_input=None,  # Not used
    -        config={},  # this is not used as the search space is defined in the hydra config
    -    )
    -
    -    dist.cleanup()
    +    try:
    +        # Convert model (convert from HF to DeciLM, score pruning activations,
    +        # prune the model and save pruned checkpoints)
    +        input_model = PuzzletronModel()
    +        converted_model = mtn.convert(
    +            input_model,
    +            mode=[
    +                (
    +                    "puzzletron",
    +                    {
    +                        "puzzle_dir": str(hydra_cfg.puzzle_dir),
    +                        "input_model_path": hydra_cfg.input_hf_model_path,
    +                        "hydra_config_dir": hydra_config_dir,
    +                        "hydra_config_name": hydra_config_name,
    +                        "dataset_path": str(hydra_cfg.dataset_path),
    +                    },
    +                )
    +            ],
    +        )
    +
    +        # Run NAS search (build replacement library and compute stats,
    +        # compute one block scores, run MIP and realize models)
    +        mtn.search(
    +            converted_model,
    +            constraints={},  # this is not used as the search space is defined in the hydra config
    +            dummy_input=None,  # Not used
    +            config={},  # this is not used as the search space is defined in the hydra config
    +        )
    +    finally:
    +        dist.cleanup()
         mprint(f"Puzzletron Progress {n}/{n}: puzzletron pipeline completed (multi-gpu)")
    @@
     def run_mip_only(hydra_config_path: str):
    @@
    -    # Check if sweep mode is enabled
    -    if hasattr(hydra_cfg.mip, "sweep") and hydra_cfg.mip.sweep.get("enabled", False):
    -        mprint(
    -            f"Puzzletron Progress {mip_step}/{n}: running MIP sweep for multiple compression rates (multi-gpu)"
    -        )
    -        sweep.run_mip_sweep(hydra_cfg)
    -    else:
    -        # mip_and_realize_models (distributed processing)
    -        # TODO: How to make it part of mnt.search() api, similarly to run_full_puzzletron() API
    -        mprint(f"Puzzletron Progress {mip_step}/{n}: running MIP and realizing models (multi-gpu)")
    -        mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg)
    -
    -    dist.cleanup()
    +    try:
    +        # Check if sweep mode is enabled
    +        if hasattr(hydra_cfg.mip, "sweep") and hydra_cfg.mip.sweep.get("enabled", False):
    +            mprint(
    +                f"Puzzletron Progress {mip_step}/{n}: running MIP sweep for multiple compression rates (multi-gpu)"
    +            )
    +            sweep.run_mip_sweep(hydra_cfg)
    +        else:
    +            # mip_and_realize_models (distributed processing)
    +            # TODO: How to make it part of mnt.search() api, similarly to run_full_puzzletron() API
    +            mprint(f"Puzzletron Progress {mip_step}/{n}: running MIP and realizing models (multi-gpu)")
    +            mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg)
    +    finally:
    +        dist.cleanup()
         mprint(f"Puzzletron Progress {n}/{n}: puzzletron pipeline completed (multi-gpu)")

Also applies to: 147-163
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/puzzletron/main.py` around lines 102 - 135, The current flow calls dist.cleanup() after running mtn.convert and mtn.search but if mtn.convert/mtn.search (or any subsequent step) raises an exception the cleanup is skipped; wrap the multi-GPU pipeline (from _setup through mtn.search/mtn.sweep/mtn.MIP calls around lines that create PuzzletronModel, call mtn.convert and mtn.search) in a try/finally block so dist.cleanup() always runs, and apply the same try/finally pattern to the other runner block referenced around lines 147-163; ensure the try encompasses all work that requires the distributed group and the finally calls dist.cleanup() unconditionally.
🧹 Nitpick comments (1)
modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py (1)

377-381: Consider clearer error handling for unknown `block_loss_func`.

If `cfg.model_factory.block_loss_func` is not one of the three expected values, a `KeyError` is raised with a cryptic message. A more informative error would help users diagnose configuration issues.

♻️ Suggested improvement

    -    block_loss_func = {
    +    _BLOCK_LOSS_FUNCS = {
             "normalized_mse_loss": normalized_mse_loss,
             "vectorwise_normalized_mse_loss": vectorwise_normalized_mse_loss,
             "batched_normalized_mse_loss": batched_normalized_mse_loss,
    -    }[cfg.model_factory.block_loss_func]
    +    }
    +    loss_func_name = cfg.model_factory.block_loss_func
    +    if loss_func_name not in _BLOCK_LOSS_FUNCS:
    +        raise ValueError(
    +            f"Unknown block_loss_func '{loss_func_name}'. "
    +            f"Expected one of: {list(_BLOCK_LOSS_FUNCS.keys())}"
    +        )
    +    block_loss_func = _BLOCK_LOSS_FUNCS[loss_func_name]

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py` around lines 377 - 381, The current lookup for block_loss_func using a dict keyed by cfg.model_factory.block_loss_func can raise a cryptic KeyError; update the code around block_loss_func (in stitched_model_factory.py) to explicitly validate cfg.model_factory.block_loss_func against the allowed names ("normalized_mse_loss", "vectorwise_normalized_mse_loss", "batched_normalized_mse_loss") and raise a clear ValueError that includes the invalid value and the list of valid options; reference the existing functions normalized_mse_loss, vectorwise_normalized_mse_loss, and batched_normalized_mse_loss when constructing the mapping and error message so users can quickly see the supported choices.
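The validated-dispatch pattern recommended in that nitpick can be exercised standalone. In this sketch the lambdas are hypothetical stand-ins for the real loss functions; only the lookup-with-validation shape is the point:

```python
def resolve_loss(name):
    # Registry of allowed loss names; values here are placeholders for
    # the actual normalized-MSE implementations in sewing_kit.utils.
    registry = {
        "normalized_mse_loss": lambda: "nmse",
        "vectorwise_normalized_mse_loss": lambda: "vnmse",
        "batched_normalized_mse_loss": lambda: "bnmse",
    }
    if name not in registry:
        # A ValueError naming the bad value and the valid options beats
        # the bare KeyError a direct dict subscript would raise.
        raise ValueError(
            f"Unknown block_loss_func {name!r}. Expected one of: {sorted(registry)}"
        )
    return registry[name]

print(resolve_loss("normalized_mse_loss")())  # nmse
```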
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/puzzletron/main.py`:
- Around line 99-101: The docstring parameter name is incorrect — replace the
documented `config_path` with the actual function parameter `hydra_config_path`
and update its description to match (e.g., "Path to the YAML configuration
file") so the `hydra_config_path` argument in the function signature and the
docstring are consistent; locate the docstring in examples/puzzletron/main.py
near the function that accepts `hydra_config_path` and make this single-name
correction.
---
Outside diff comments:
In `@examples/puzzletron/main.py`:
- Around lines 102-135: dist.cleanup() is called after mtn.convert and mtn.search run, so if mtn.convert/mtn.search (or any subsequent step) raises, cleanup is skipped. Wrap the multi-GPU pipeline (from _setup through the mtn.convert/mtn.search/mtn.sweep/mtn.MIP calls that create PuzzletronModel) in a try/finally so dist.cleanup() always runs, and apply the same pattern to the other runner block around lines 147-163. The try must encompass all work that requires the distributed group, and the finally must call dist.cleanup() unconditionally.
---
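The try/finally pattern the comment above asks for can be sketched as follows; the function names here are placeholders standing in for _setup, mtn.convert/mtn.search, and dist.cleanup, not the actual pipeline code:

```python
def run_pipeline(setup, convert, search, cleanup):
    # Placeholder names: setup ~ _setup, convert ~ mtn.convert,
    # search ~ mtn.search, cleanup ~ dist.cleanup.
    setup()
    try:
        model = convert()
        return search(model)
    finally:
        # Runs even if convert() or search() raises, so the distributed
        # process group is always torn down.
        cleanup()
```

With this shape, an exception anywhere after setup still propagates to the caller, but the process group is released first instead of leaking.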
📒 Files selected for processing (3)
- examples/puzzletron/main.py
- modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py
- modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py

🚧 Files skipped from review as they are similar to previous changes (1)
- modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py
```python
    Args:
        config_path: Path to the YAML configuration file
    """
```
Fix docstring argument name mismatch.
Line 100 documents config_path, but the function argument is hydra_config_path. Please align the docstring to avoid confusion.
Proposed fix:

```diff
 def run_full_puzzletron(hydra_config_path: str):
@@
     Args:
-        config_path: Path to the YAML configuration file
+        hydra_config_path: Path to the YAML configuration file
```
Extract four self-contained blocks from the 436-line train() function into named helpers, reducing it to ~290 lines:

- _save_final_checkpoint(): saves the final checkpoint when max_steps is reached and cleans up old iter-* checkpoints
- _log_training_stats(): master-only block that processes loss history in log_interval chunks, updates best-loss tracking, prints tables via format_stitched_losses, and optionally logs to W&B
- _run_validation(): runs the distributed validation pipeline, broadcasts val_loss from the last rank, and saves the best checkpoint if validation loss improved
- _save_interval_checkpoint(): handles step-interval and time-based checkpoint saving, including kill_after_first_save semantics

No behavioral changes — pure mechanical extraction.
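The shape of such a mechanical extraction can be illustrated with a toy version; the helper name mirrors _log_training_stats above, but the real signatures in the PR will differ:

```python
def _log_training_stats(step, loss_history, log_interval):
    # Extracted logging helper: reports the mean loss over the most
    # recent log_interval steps (master-only in the real code).
    window = loss_history[-log_interval:]
    print(f"step {step}: mean loss {sum(window) / len(window):.4f}")

def train(losses, log_interval=2):
    # Slimmed-down train() loop that delegates logging to the helper;
    # checkpointing and validation helpers would be invoked the same way.
    loss_history = []
    for step, loss in enumerate(losses, start=1):
        loss_history.append(loss)
        if step % log_interval == 0:
            _log_training_stats(step, loss_history, log_interval)
    return loss_history
```

Because the helpers only receive state they already read inside train(), the extraction changes no behavior, which is what makes it safe to do without new tests.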
@cjluo-nv addressed all the points (thanks again for the great review)